# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles function calls by rewriting them into `ag__.converted_call` invocations.

Note: this transformer does not rename the top level object being converted;
that is the caller's responsibility.

Requires the function_scopes converter to have been applied first.
"""

import gast

from nvidia.dali._autograph.core import converter
from nvidia.dali._autograph.pyct import anno
from nvidia.dali._autograph.pyct import parser
from nvidia.dali._autograph.pyct import qual_names
from nvidia.dali._autograph.pyct import templates
from nvidia.dali._autograph.utils import ag_logging


# TODO(mdan): Rename to FunctionCallsTransformer.


class _Function(object):
    no_root = True

    def __init__(self):
        self.context_name = None


# Set to True once CallTreeTransformer.visit_Call has logged the debugger-call
# warning; nothing in this file resets it, so that warning is logged at most once.
set_trace_warned = False


class _ArgTemplateBuilder(object):
    """Constructs a tuple representing the positional arguments in a call.

    Example (yes, it's legal Python 3):

        f(*args1, b, *args2, c, d)  ->  args1 + (b,) + args2 + (c, d)
    """

    def __init__(self):
        self._arg_accumulator = []
        self._argspec = []
        self._finalized = False

    def _consume_args(self):
        if self._arg_accumulator:
            self._argspec.append(gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
            self._arg_accumulator = []

    def add_arg(self, a):
        self._arg_accumulator.append(a)

    def add_stararg(self, a):
        self._consume_args()
        self._argspec.append(
            gast.Call(
                gast.Name("tuple", ctx=gast.Load(), annotation=None, type_comment=None),
                args=[a],
                keywords=(),
            )
        )

    def finalize(self):
        self._consume_args()
        self._finalized = True

    def to_ast(self):
        assert self._finalized
        if self._argspec:
            result = self._argspec[0]
            for i in range(1, len(self._argspec)):
                result = gast.BinOp(result, gast.Add(), self._argspec[i])
            return result
        return gast.Tuple([], gast.Load())


class CallTreeTransformer(converter.Base):
    """Transforms the call tree by routing calls through ``ag__.converted_call``.

    Each ``Call`` node is rewritten to
    ``ag__.converted_call(func, args, kwargs, function_ctx)``. The exceptions
    are calls into ``ag__`` itself, calls on the enclosing function's scope
    object, debugger entry points, and ``print`` when builtin-function
    conversion is disabled.
    """

    def visit_Lambda(self, node):
        """Visits a lambda, entering a new function scope only if it has one."""
        if not anno.hasanno(node, "function_context_name"):
            # Lambda functions created during the conversion process have no
            # context manager.
            return self.generic_visit(node)
        with self.state[_Function] as fn_scope:
            fn_scope.context_name = anno.getanno(node, "function_context_name")
            return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Visits outer-scope parts first, then the body in a new function scope."""
        # Decorators and arg defaults are part of the outer scope.
        node.decorator_list = self.visit_block(node.decorator_list)
        node.args.defaults = self.visit_block(node.args.defaults)
        for i, d in enumerate(node.args.kw_defaults):
            # kw_defaults holds None for keyword-only args without a default.
            if d is not None:
                node.args.kw_defaults[i] = self.visit(d)
        with self.state[_Function] as fn_scope:
            # Note: if the conversion process ever creates helper functions, this
            # assumption will no longer hold.
            assert anno.hasanno(
                node, "function_context_name"
            ), "The function_scopes converter always creates a scope for functions."
            fn_scope.context_name = anno.getanno(node, "function_context_name")
            node.body = self.visit_block(node.body)
            # NOTE(review): the return annotation is visited with this function's
            # scope as context, although Python evaluates it in the enclosing
            # scope at definition time (like the decorators above). Parameter
            # annotations are not visited at all. Confirm this is intended.
            if node.returns:
                node.returns = self.visit(node.returns)
            return node

    def visit_With(self, node):
        """Converts a ``with`` body, leaving its context-manager expressions as-is."""
        # Context manager calls (in node.items) are not converted.
        node.body = self.visit_block(node.body)
        return node

    def _args_to_tuple(self, node):
        """Ties together all positional and *arg arguments in a single tuple."""
        # TODO(mdan): We could rewrite this to just a tuple display. Maybe better?
        # For example for
        #   f(a, b, *args)
        # instead of writing:
        #   (a, b) + args
        # just write this?
        #   (a, b, *args)
        builder = _ArgTemplateBuilder()
        for a in node.args:
            if isinstance(a, gast.Starred):
                builder.add_stararg(a.value)
            else:
                builder.add_arg(a)
        builder.finalize()
        return builder.to_ast()

    def _kwargs_to_dict(self, node):
        """Ties together all keyword and **kwarg arguments in a single dict.

        Returns an AST for ``dict(<node's keywords>)``, or for ``None`` when the
        call has no keyword arguments.
        """
        if node.keywords:
            return gast.Call(
                gast.Name("dict", ctx=gast.Load(), annotation=None, type_comment=None),
                args=(),
                keywords=node.keywords,
            )
        else:
            return parser.parse_expression("None")

    def visit_Call(self, node):
        """Rewrites a call into ``ag__.converted_call`` unless it must stay as-is.

        Child nodes (including nested calls in the arguments) are visited first.
        """
        global set_trace_warned

        full_name = str(anno.getanno(node.func, anno.Basic.QN, default=""))
        function_context_name = self.state[_Function].context_name
        node = self.generic_visit(node)

        # TODO(mdan): Refactor converted_call as a 'Call' operator.

        # Calls to the internal 'ag__' module are never converted (though their
        # arguments might be).
        if full_name.startswith("ag__."):
            return node

        # Calls to the function context manager (inserted by function_scopes) are
        # also safe.
        if full_name.startswith(function_context_name + "."):
            return node

        # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
        # the normal mechanisms to bypass these literals because they are sensitive
        # to the frame they are being called from.
        # TODO(mdan): Generalize this to a "static allowlist" config.
        if full_name in ("pdb.set_trace", "ipdb.set_trace", "breakpoint"):
            if not set_trace_warned:
                # TODO(klecki): Point to a DALI-specific documentation here.
                # Name the call actually found (it may be breakpoint() or
                # ipdb.set_trace(), not only pdb.set_trace()).
                msg = (
                    "Detected `{}()` in user code. The code"
                    " generated by AutoGraph is not optimized for step-by-step"
                    " debugging."
                ).format(full_name)
                ag_logging.warning(msg)
                set_trace_warned = True
            return node

        if full_name == "print" and not self.ctx.user.options.uses(
            converter.Feature.BUILTIN_FUNCTIONS
        ):
            return node

        template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
        new_call = templates.replace_as_expression(
            template,
            func=node.func,
            args=self._args_to_tuple(node),
            kwargs=self._kwargs_to_dict(node),
            function_ctx=function_context_name,
        )

        return new_call


def transform(node, ctx):
    """Transform function calls into their AutoGraph-converted counterparts.

    Qualified names are resolved first. ``CallTreeTransformer.visit_Call`` reads
    the ``anno.Basic.QN`` annotation of each call target (presumably set by
    ``qual_names.resolve``) to decide which calls to leave untouched.

    Args:
      node: AST
      ctx: EntityContext
    Returns:
      The transformed AST node. Unlike what an earlier version of this
      docstring said, no set of new names is returned.
    """
    node = qual_names.resolve(node)
    return CallTreeTransformer(ctx).visit(node)
